{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/dauparas/ProteinMPNN/blob/main/colab_notebooks/quickdemo_wAF2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AYZebfKn8gef"
      },
      "source": [
        "# ProteinMPNN w/AF2\n",
        "This notebook is intended as a quick demo; more features to come!\n",
        "\n",
        "Examples: \n",
        "1.   pdb: `6MRR`, homomer: `False`, designed_chain: `A`\n",
        "2.   pdb: `1X2I`, homomer: `True`, designed_chain: `A,B` \n",
        "     (for correct symmetric tying, the lengths of the homomer chains should be the same)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Setup ProteinMPNN\n",
        "import warnings\n",
        "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
        "\n",
        "import json, time, os, sys, glob, re\n",
        "from google.colab import files\n",
        "import numpy as np\n",
        "\n",
        "# clone the ProteinMPNN repository once; later runs reuse the existing checkout\n",
        "if not os.path.isdir(\"ProteinMPNN\"):\n",
        "  os.system(\"git clone -q https://github.com/dauparas/ProteinMPNN.git\")\n",
        "  if not os.path.isdir(\"ProteinMPNN\"):\n",
        "    raise RuntimeError(\"git clone of ProteinMPNN failed\")\n",
        "\n",
        "# Test for the exact path that gets appended (the checkout created above, in the\n",
        "# working directory), so re-running this cell does not add duplicate sys.path entries.\n",
        "PROTEINMPNN_DIR = os.path.abspath(\"ProteinMPNN\")\n",
        "if PROTEINMPNN_DIR not in sys.path:\n",
        "  sys.path.append(PROTEINMPNN_DIR)\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import shutil\n",
        "import torch\n",
        "from torch import optim\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.utils.data.dataset import random_split, Subset\n",
        "import copy\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import random\n",
        "import os.path\n",
        "from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB\n",
        "from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN\n",
        "\n",
        "# NOTE(review): ProteinMPNN runs on CPU here, presumably to leave the GPU to AlphaFold/JAX; confirm\n",
        "device = torch.device(\"cpu\")\n",
        "#v_48_010=version with 48 edges 0.10A noise\n",
        "model_name = \"v_48_020\" #@param [\"v_48_002\", \"v_48_010\", \"v_48_020\", \"v_48_030\"]\n",
        "\n",
        "\n",
        "backbone_noise=0.00               # Standard deviation of Gaussian noise to add to backbone atoms\n",
        "\n",
        "path_to_model_weights=os.path.join(PROTEINMPNN_DIR, 'vanilla_model_weights')\n",
        "hidden_dim = 128\n",
        "num_layers = 3\n",
        "model_folder_path = path_to_model_weights\n",
        "if model_folder_path[-1] != '/':\n",
        "    model_folder_path = model_folder_path + '/'\n",
        "checkpoint_path = model_folder_path + f'{model_name}.pt'\n",
        "\n",
        "checkpoint = torch.load(checkpoint_path, map_location=device)\n",
        "print('Number of edges:', checkpoint['num_edges'])\n",
        "noise_level_print = checkpoint['noise_level']\n",
        "print(f'Training noise level: {noise_level_print}A')\n",
        "model = ProteinMPNN(num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])\n",
        "model.to(device)\n",
        "model.load_state_dict(checkpoint['model_state_dict'])\n",
        "model.eval()\n",
        "print(\"Model loaded\")\n",
        "\n",
        "def make_tied_positions_for_homomers(pdb_dict_list):\n",
        "    \"\"\"Tie each residue position across all chains of every parsed structure.\n",
        "\n",
        "    For each dict in pdb_dict_list, the chain letters are the last character of\n",
        "    its seq_chain* keys, sorted. The number of positions is the length of the\n",
        "    first of those chains' sequences; position i (1-based) becomes {chain: [i], ...}.\n",
        "    Returns {structure name: [one such dict per position]}.\n",
        "    \"\"\"\n",
        "    tied_by_name = {}\n",
        "    for entry in pdb_dict_list:\n",
        "        chains = sorted(key[-1:] for key in entry if key.startswith('seq_chain'))  # A, B, C, ...\n",
        "        n_positions = len(entry[\"seq_chain_\" + chains[0]])\n",
        "        # each chain maps to a list of positions (must be a list, not a bare int)\n",
        "        tied_by_name[entry['name']] = [\n",
        "            {chain: [pos] for chain in chains} for pos in range(1, n_positions + 1)\n",
        "        ]\n",
        "    return tied_by_name\n",
        "\n",
        "#########################\n",
        "def get_pdb(pdb_code=\"\"):\n",
        "  \"\"\"Return a local path to the input structure.\n",
        "\n",
        "  With an empty or None pdb_code, prompt for an upload (Colab files.upload) and\n",
        "  save the first uploaded file as tmp.pdb. Otherwise download {pdb_code}.pdb from\n",
        "  RCSB into the working directory (wget -nc skips it if already present).\n",
        "\n",
        "  Raises ValueError if nothing was uploaded or if pdb_code contains characters\n",
        "  other than letters, digits and underscores (it is placed in a shell command),\n",
        "  and FileNotFoundError if the download did not produce the file.\n",
        "  \"\"\"\n",
        "  if pdb_code is None or pdb_code == \"\":\n",
        "    upload_dict = files.upload()\n",
        "    if not upload_dict:\n",
        "      raise ValueError(\"No file was uploaded\")\n",
        "    pdb_string = upload_dict[list(upload_dict.keys())[0]]\n",
        "    with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n",
        "    return \"tmp.pdb\"\n",
        "  # pdb_code is interpolated into a shell command: allow plain identifier characters only\n",
        "  if not re.fullmatch(r\"[A-Za-z0-9_]+\", pdb_code):\n",
        "    raise ValueError(f\"Invalid PDB code: {pdb_code!r}\")\n",
        "  pdb_path = f\"{pdb_code}.pdb\"\n",
        "  os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_path}\")\n",
        "  if not os.path.isfile(pdb_path):\n",
        "    raise FileNotFoundError(f\"Could not download {pdb_path} from RCSB\")\n",
        "  return pdb_path"
      ],
      "metadata": {
        "id": "2nKSlaMlSpcf",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "xMVlYh8Fv2of"
      },
      "outputs": [],
      "source": [
        "#@title #Run ProteinMPNN\n",
        "\n",
        "#@markdown #### Input Options\n",
        "pdb='6MRR' #@param {type:\"string\"}\n",
        "pdb = pdb.replace(\" \",\"\")\n",
        "pdb_path = get_pdb(pdb)\n",
        "#@markdown - pdb code (leave blank to get an upload prompt)\n",
        "\n",
        "homomer = False #@param {type:\"boolean\"}\n",
        "designed_chain = \"A\" #@param {type:\"string\"}\n",
        "fixed_chain = \"\" #@param {type:\"string\"}\n",
        "\n",
        "def _split_chains(chains_text):\n",
        "  \"\"\"Split a chain spec such as 'A,B' or 'A B' into ['A', 'B'].\n",
        "\n",
        "  Any run of non-letters is a separator; empty pieces (blank field, leading,\n",
        "  trailing or repeated separators) are dropped.\"\"\"\n",
        "  return [chain for chain in re.split(\"[^A-Za-z]+\", chains_text) if chain]\n",
        "\n",
        "designed_chain_list = _split_chains(designed_chain)\n",
        "fixed_chain_list = _split_chains(fixed_chain)\n",
        "\n",
        "# sorted (instead of raw set order) so the chain order is the same on every run\n",
        "chain_list = sorted(set(designed_chain_list + fixed_chain_list))\n",
        "\n",
        "#@markdown - specify which chain(s) to design and which chain(s) to keep fixed.\n",
        "#@markdown   Use a comma (`A,B`) to specify more than one chain\n",
        "\n",
        "#@markdown #### Design Options\n",
        "num_seqs = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n",
        "num_seq_per_target = num_seqs\n",
        "\n",
        "#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.\n",
        "sampling_temp = \"0.1\" #@param [\"0.0001\", \"0.1\", \"0.15\", \"0.2\", \"0.25\", \"0.3\", \"0.5\"]\n",
        "\n",
        "\n",
        "\n",
        "save_score=0                      # 0 for False, 1 for True; save score=-log_prob to npy files\n",
        "save_probs=0                      # 0 for False, 1 for True; save MPNN predicted probabilites per position\n",
        "score_only=0                      # 0 for False, 1 for True; score input backbone-sequence pairs\n",
        "conditional_probs_only=0          # 0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone)\n",
        "conditional_probs_only_backbone=0 # 0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone)\n",
        "\n",
        "batch_size=1                      # Batch size; can set higher for titan, quadro GPUs, reduce this if running out of GPU memory\n",
        "max_length=20000                  # Max sequence length\n",
        "\n",
        "out_folder='.'                    # Path to a folder to output sequences, e.g. /home/out/\n",
        "jsonl_path=''                     # Path to a folder with parsed pdb into jsonl\n",
        "omit_AAs='X'                      # Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cystine.\n",
        "\n",
        "pssm_multi=0.0                    # A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions\n",
        "pssm_threshold=0.0                # A value between -inf + inf to restric per position AAs\n",
        "pssm_log_odds_flag=0               # 0 for False, 1 for True\n",
        "pssm_bias_flag=0                   # 0 for False, 1 for True\n",
        "\n",
        "\n",
        "##############################################################\n",
        "\n",
        "folder_for_outputs = out_folder\n",
        "\n",
        "NUM_BATCHES = num_seq_per_target//batch_size\n",
        "BATCH_COPIES = batch_size\n",
        "temperatures = [float(item) for item in sampling_temp.split()]\n",
        "omit_AAs_list = omit_AAs\n",
        "alphabet = 'ACDEFGHIKLMNPQRSTVWYX'\n",
        "\n",
        "omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)\n",
        "\n",
        "chain_id_dict = None\n",
        "fixed_positions_dict = None\n",
        "pssm_dict = None\n",
        "omit_AA_dict = None\n",
        "bias_AA_dict = None\n",
        "tied_positions_dict = None\n",
        "bias_by_res_dict = None\n",
        "bias_AAs_np = np.zeros(len(alphabet))\n",
        "\n",
        "\n",
        "###############################################################\n",
        "pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)\n",
        "\n",
        "# fail early with a clear message instead of a KeyError further down\n",
        "missing_chains = [chain for chain in chain_list if f\"seq_chain_{chain}\" not in pdb_dict_list[0]]\n",
        "if missing_chains:\n",
        "  raise ValueError(f\"Chain(s) {missing_chains} not found in {pdb_path}\")\n",
        "\n",
        "dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)\n",
        "\n",
        "chain_id_dict = {}\n",
        "chain_id_dict[pdb_dict_list[0]['name']]= (designed_chain_list, fixed_chain_list)\n",
        "\n",
        "print(chain_id_dict)\n",
        "for chain in chain_list:\n",
        "  l = len(pdb_dict_list[0][f\"seq_chain_{chain}\"])\n",
        "  print(f\"Length of chain {chain} is {l}\")\n",
        "\n",
        "if homomer:\n",
        "  homomer_lengths = {len(pdb_dict_list[0][f\"seq_chain_{chain}\"]) for chain in chain_list}\n",
        "  if len(homomer_lengths) > 1:\n",
        "    print(\"WARNING: homomer chains have different lengths; symmetric tying assumes equal lengths\")\n",
        "  tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list)\n",
        "else:\n",
        "  tied_positions_dict = None\n",
        "\n",
        "#################################################################\n",
        "def _format_chain_seq(seq_str, chain_lengths, chain_letters):\n",
        "  \"\"\"Split a concatenated designed-chain sequence into per-chain pieces and\n",
        "  re-join them in alphabetical chain order, separated by '/'.\n",
        "\n",
        "  seq_str is expected to be the chains of chain_letters concatenated in that\n",
        "  order, with chain_lengths[i] residues for chain_letters[i].\"\"\"\n",
        "  pieces = []\n",
        "  start = 0\n",
        "  for chain_len in chain_lengths:\n",
        "    pieces.append(seq_str[start:start + chain_len])\n",
        "    start += chain_len\n",
        "  return \"/\".join(pieces[i] for i in np.argsort(chain_letters))\n",
        "\n",
        "sequences = []\n",
        "with torch.no_grad():\n",
        "  print('Generating sequences...')\n",
        "  for ix, protein in enumerate(dataset_valid):\n",
        "    score_list = []\n",
        "    all_probs_list = []\n",
        "    all_log_probs_list = []\n",
        "    S_sample_list = []\n",
        "    batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]\n",
        "    X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict)\n",
        "    pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float() #1.0 for true, 0.0 for false\n",
        "    name_ = batch_clones[0]['name']\n",
        "\n",
        "    # score the native (input) sequence\n",
        "    randn_1 = torch.randn(chain_M.shape, device=X.device)\n",
        "    log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)\n",
        "    # positions counted in every score; loop-invariant, so it is reused for all samples below\n",
        "    mask_for_loss = mask*chain_M*chain_M_pos\n",
        "    scores = _scores(S, log_probs, mask_for_loss)\n",
        "    native_score = scores.cpu().data.numpy()\n",
        "\n",
        "    for temp in temperatures:\n",
        "        for j in range(NUM_BATCHES):\n",
        "            randn_2 = torch.randn(chain_M.shape, device=X.device)\n",
        "            if tied_positions_dict is None:\n",
        "                sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)\n",
        "            else:\n",
        "                sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta, bias_by_res=bias_by_res_all)\n",
        "            S_sample = sample_dict[\"S\"]\n",
        "            # Compute scores of the sample, decoding in the same order it was sampled\n",
        "            log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_2, use_input_decoding_order=True, decoding_order=sample_dict[\"decoding_order\"])\n",
        "            scores = _scores(S_sample, log_probs, mask_for_loss)\n",
        "            scores = scores.cpu().data.numpy()\n",
        "            all_probs_list.append(sample_dict[\"probs\"].cpu().data.numpy())\n",
        "            all_log_probs_list.append(log_probs.cpu().data.numpy())\n",
        "            S_sample_list.append(S_sample.cpu().data.numpy())\n",
        "            for b_ix in range(BATCH_COPIES):\n",
        "                masked_chain_length_list = masked_chain_length_list_list[b_ix]\n",
        "                masked_list = masked_list_list[b_ix]\n",
        "                seq_recovery_rate = torch.sum(torch.sum(torch.nn.functional.one_hot(S[b_ix], 21)*torch.nn.functional.one_hot(S_sample[b_ix], 21),axis=-1)*mask_for_loss[b_ix])/torch.sum(mask_for_loss[b_ix])\n",
        "                seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])\n",
        "                score = scores[b_ix]\n",
        "                score_list.append(score)\n",
        "                if b_ix == 0 and j==0 and temp==temperatures[0]:\n",
        "                    # print the native sequence once, as a header above the samples\n",
        "                    native_seq = _format_chain_seq(_S_to_seq(S[b_ix], chain_M[b_ix]), masked_chain_length_list, masked_list)\n",
        "                    sorted_masked_chain_letters = np.argsort(masked_list_list[0])\n",
        "                    print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]\n",
        "                    sorted_visible_chain_letters = np.argsort(visible_list_list[0])\n",
        "                    print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]\n",
        "                    native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)\n",
        "                    line = '>{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\\n{}\\n'.format(name_, native_score_print, print_visible_chains, print_masked_chains, model_name, native_seq)\n",
        "                    print(line.rstrip())\n",
        "                seq = _format_chain_seq(seq, masked_chain_length_list, masked_list)\n",
        "                score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)\n",
        "                seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)\n",
        "                line = '>T={}, sample={}, score={}, seq_recovery={}\\n{}\\n'.format(temp,b_ix,score_print,seq_rec_print,seq)\n",
        "                sequences.append(seq)\n",
        "                print(line.rstrip())\n",
        "\n",
        "\n",
        "all_probs_concat = np.concatenate(all_probs_list)\n",
        "all_log_probs_concat = np.concatenate(all_log_probs_list)\n",
        "S_sample_concat = np.concatenate(S_sample_list)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Predict with AlphaFold2 (with single-sequence input)"
      ],
      "metadata": {
        "id": "5mQ4VLG1dPsd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Setup AlphaFold\n",
        "\n",
        "# import libraries\n",
        "from IPython.utils import io\n",
        "import os,sys,re\n",
        "\n",
        "if \"af_backprop\" not in sys.path:\n",
        "  import tensorflow as tf\n",
        "  import jax\n",
        "  import jax.numpy as jnp\n",
        "  import numpy as np\n",
        "  import matplotlib\n",
        "  from matplotlib import animation\n",
        "  import matplotlib.pyplot as plt\n",
        "  from IPython.display import HTML\n",
        "  import tqdm.notebook\n",
        "  TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
        "\n",
        "  with io.capture_output() as captured:\n",
        "    # install ALPHAFOLD\n",
        "    if not os.path.isdir(\"af_backprop\"):\n",
        "      %shell git clone https://github.com/sokrypton/af_backprop.git\n",
        "      %shell pip -q install biopython dm-haiku ml-collections py3Dmol\n",
        "      %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py\n",
        "    if not os.path.isdir(\"params\"):\n",
        "      %shell mkdir params\n",
        "      %shell curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n",
        "\n",
        "  if not os.path.exists(\"MMalign\"):\n",
        "    # install MMalign\n",
        "    os.system(\"wget -qnc https://zhanggroup.org/MM-align/bin/module/MMalign.cpp\")\n",
        "    os.system(\"g++ -static -O3 -ffast-math -o MMalign MMalign.cpp\")\n",
        "    if not os.path.exists(\"MMalign\"):\n",
        "      print(\"WARNING: could not build MMalign; TM-scores will not be available\")\n",
        "\n",
        "  def mmalign(pdb_a,pdb_b):\n",
        "    \"\"\"Run ./MMalign on two structure files and return the TM-scores it prints.\n",
        "\n",
        "    Returns one float per output line starting with \"TM-score\"; the list is\n",
        "    empty if MMalign printed no such line (e.g. the binary is missing).\"\"\"\n",
        "    parse_float = lambda x: float(x.split(\"=\")[1].split()[0])\n",
        "    tms = []\n",
        "    # the with-block closes the pipe to the MMalign process once it is read\n",
        "    with os.popen(f'./MMalign {pdb_a} {pdb_b}') as output:\n",
        "      for line in output:\n",
        "        line = line.rstrip()\n",
        "        if line.startswith(\"TM-score\"): tms.append(parse_float(line))\n",
        "    return tms\n",
        "\n",
        "  # configure which device to use\n",
        "  try:\n",
        "    # check if TPU is available\n",
        "    import jax.tools.colab_tpu\n",
        "    jax.tools.colab_tpu.setup_tpu()\n",
        "    print('Running on TPU')\n",
        "    DEVICE = \"tpu\"\n",
        "  except Exception:\n",
        "    # TPU import/setup failed: fall back to GPU or CPU (without swallowing KeyboardInterrupt)\n",
        "    if jax.local_devices()[0].platform == 'cpu':\n",
        "      print(\"WARNING: no GPU detected, will be using CPU\")\n",
        "      DEVICE = \"cpu\"\n",
        "    else:\n",
        "      print('Running on GPU')\n",
        "      DEVICE = \"gpu\"\n",
        "      # disable GPU on tensorflow\n",
        "      tf.config.set_visible_devices([], 'GPU')\n",
        "\n",
        "  # import libraries\n",
        "  sys.path.append('af_backprop')\n",
        "  from utils import update_seq, update_aatype, get_plddt, get_pae\n",
        "  import colabfold as cf\n",
        "  from alphafold.common import protein as alphafold_protein\n",
        "  from alphafold.data import pipeline\n",
        "  from alphafold.model import data, config\n",
        "  from alphafold.common import residue_constants\n",
        "  from alphafold.model import model as alphafold_model\n",
        "\n",
        "# custom functions\n",
        "def clear_mem():\n",
        "  \"\"\"Delete every buffer the current JAX backend still holds alive.\"\"\"\n",
        "  for live_buffer in jax.lib.xla_bridge.get_backend().live_buffers():\n",
        "    live_buffer.delete()\n",
        "\n",
        "def setup_model(max_len):\n",
        "  \"\"\"Build a jitted single-sequence AlphaFold2 runner for inputs padded to max_len.\n",
        "\n",
        "  Frees JAX buffers, configures model_3_ptm with num_recycle=0 and a single MSA\n",
        "  row, loads the pTM parameters of models 1-5 (each restricted to the keys of\n",
        "  model_3_ptm's parameters) and featurizes a dummy poly-A sequence of length\n",
        "  max_len so every call shares the same input shapes.\n",
        "\n",
        "  Returns (runner, model_params, I): runner(I, params) returns a dict with atom\n",
        "  positions/mask, pLDDT, PAE, length, seq, residue_idx and the recycling state\n",
        "  under \"prev\"; model_params is a list of 5 parameter sets; I holds the processed\n",
        "  \"inputs\" and \"length\" (the caller adds \"seq\", \"prev\" and the real \"length\").\n",
        "  \"\"\"\n",
        "  clear_mem()\n",
        "\n",
        "  # setup model; recycling is driven by the caller through I[\"prev\"]\n",
        "  cfg = config.model_config(\"model_3_ptm\")\n",
        "  cfg.model.num_recycle = 0\n",
        "  cfg.data.common.num_recycle = 0\n",
        "  cfg.data.eval.max_msa_clusters = 1\n",
        "  cfg.data.common.max_extra_msa = 1\n",
        "  cfg.data.eval.masked_msa_replace_fraction = 0\n",
        "  cfg.model.global_config.subbatch_size = None\n",
        "\n",
        "  # get params\n",
        "  model_param = data.get_model_haiku_params(model_name=\"model_3_ptm\", data_dir=\".\")\n",
        "  model_runner = alphafold_model.RunModel(cfg, model_param, is_training=False, recycle_mode=\"none\")\n",
        "\n",
        "  model_params = []\n",
        "  for k in [1,2,3,4,5]:\n",
        "    if k == 3:\n",
        "      model_params.append(model_param)\n",
        "    else:\n",
        "      params = data.get_model_haiku_params(model_name=f\"model_{k}_ptm\", data_dir=\".\")\n",
        "      # keep only the parameter groups model_3_ptm has (presumably so all sets fit the one runner)\n",
        "      model_params.append({name: params[name] for name in model_runner.params.keys()})\n",
        "\n",
        "  seq = \"A\" * max_len\n",
        "  length = len(seq)\n",
        "  feature_dict = {\n",
        "      **pipeline.make_sequence_features(sequence=seq, description=\"none\", num_res=length),\n",
        "      **pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0]*length]])\n",
        "  }\n",
        "  inputs = model_runner.process_features(feature_dict,random_seed=0)\n",
        "\n",
        "  def runner(I, params):\n",
        "    # update sequence\n",
        "    inputs = I[\"inputs\"]\n",
        "    inputs.update(I[\"prev\"])\n",
        "\n",
        "    seq = jax.nn.one_hot(I[\"seq\"],20)\n",
        "    update_seq(seq, inputs)\n",
        "    update_aatype(inputs[\"target_feat\"][...,1:], inputs)\n",
        "\n",
        "    # mask prediction: only the first I[\"length\"] positions are real residues\n",
        "    mask = jnp.arange(inputs[\"residue_index\"].shape[0]) < I[\"length\"]\n",
        "    inputs[\"seq_mask\"] = inputs[\"seq_mask\"].at[:].set(mask)\n",
        "    inputs[\"msa_mask\"] = inputs[\"msa_mask\"].at[:].set(mask)\n",
        "    inputs[\"residue_index\"] = jnp.where(mask, inputs[\"residue_index\"], 0)\n",
        "\n",
        "    # get prediction\n",
        "    key = jax.random.PRNGKey(0)\n",
        "    outputs = model_runner.apply(params, key, inputs)\n",
        "\n",
        "    prev = {\"init_msa_first_row\":outputs['representations']['msa_first_row'][None],\n",
        "            \"init_pair\":outputs['representations']['pair'][None],\n",
        "            \"init_pos\":outputs['structure_module']['final_atom_positions'][None]}\n",
        "\n",
        "    aux = {\"final_atom_positions\":outputs[\"structure_module\"][\"final_atom_positions\"],\n",
        "           \"final_atom_mask\":outputs[\"structure_module\"][\"final_atom_mask\"],\n",
        "           \"plddt\":get_plddt(outputs),\"pae\":get_pae(outputs),\n",
        "           \"length\":I[\"length\"], \"seq\":I[\"seq\"], \"prev\":prev,\n",
        "           \"residue_idx\":inputs[\"residue_index\"][0]}\n",
        "    return aux\n",
        "\n",
        "  # \"length\" is this model's padded size (max_len), not the global ProteinMPNN max_length\n",
        "  return jax.jit(runner), model_params, {\"inputs\":inputs, \"length\":max_len}\n",
        "\n",
        "def save_pdb(outs, filename, Ls=None):\n",
        "  \"\"\"Write the predicted structure held in `outs` to `filename` in PDB format.\n",
        "\n",
        "  All arrays are cut to the first outs[\"length\"] entries, residue_index is\n",
        "  outs[\"residue_idx\"] + 1 and B-factors are 100 * pLDDT multiplied by atom_mask.\n",
        "  If chain lengths Ls are given, the file is then rewritten with the text\n",
        "  returned by cf.read_pdb_renum(filename, Ls).\n",
        "  \"\"\"\n",
        "  n_res = outs[\"length\"]\n",
        "  fields = jax.tree_map(lambda arr: arr[:n_res], {\n",
        "      \"residue_index\": outs[\"residue_idx\"] + 1,\n",
        "      \"aatype\": outs[\"seq\"],\n",
        "      \"atom_positions\": outs[\"final_atom_positions\"],\n",
        "      \"atom_mask\": outs[\"final_atom_mask\"],\n",
        "      \"plddt\": outs[\"plddt\"],\n",
        "  })\n",
        "  plddt = fields.pop(\"plddt\")\n",
        "  structure = alphafold_protein.Protein(**fields, b_factors=100 * plddt[:, None] * fields[\"atom_mask\"])\n",
        "\n",
        "  def _write(text):\n",
        "    with open(filename, 'w') as handle:\n",
        "      handle.write(text)\n",
        "\n",
        "  _write(alphafold_protein.to_pdb(structure))\n",
        "  if Ls is not None:\n",
        "    # renumber per chain; read_pdb_renum is given the file just written\n",
        "    _write(cf.read_pdb_renum(filename, Ls))"
      ],
      "metadata": {
        "cellView": "form",
        "id": "4ZBUThXU7yY8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Run AlphaFold\n",
        "num_models = 1 #@param [\"1\",\"2\",\"3\",\"4\",\"5\"] {type:\"raw\"}\n",
        "num_recycles = 1 #@param [\"0\",\"1\",\"2\",\"3\"] {type:\"raw\"}\n",
        "num_sequences = len(sequences)\n",
        "outs = []\n",
        "positions = []\n",
        "plddts = []\n",
        "paes = []\n",
        "LS = []\n",
        "\n",
        "with tqdm.notebook.tqdm(total=(num_recycles + 1) * num_models * num_sequences, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
        "  print(\"seq_num model_num avg_pLDDT avg_pAE TMscore\")\n",
        "  for s,ori_sequence in enumerate(sequences):\n",
        "    # per-chain lengths; chains are separated by \"/\" (\":\" is accepted too)\n",
        "    Ls = [len(chain_seq) for chain_seq in ori_sequence.replace(\":\",\"/\").split(\"/\")]\n",
        "    LS.append(Ls)\n",
        "    sequence = re.sub(\"[^A-Z]\",\"\",ori_sequence)\n",
        "    length = len(sequence)\n",
        "\n",
        "    # avoid recompiling if length within 25\n",
        "    if \"max_len\" not in dir() or length > max_len or (max_len - length) > 25:\n",
        "      max_len = length + 25\n",
        "      runner, params, I = setup_model(max_len)\n",
        "\n",
        "    outs.append([])\n",
        "    positions.append([])\n",
        "    plddts.append([])\n",
        "    paes.append([])\n",
        "\n",
        "    # pad sequence to max length\n",
        "    seq = np.array([residue_constants.restype_order.get(aa,0) for aa in sequence])\n",
        "    seq = np.pad(seq,[0,max_len-length],constant_values=-1)\n",
        "    I[\"inputs\"]['residue_index'][:] = cf.chain_break(np.arange(max_len), Ls, length=32)\n",
        "    I.update({\"seq\":seq, \"length\":length})\n",
        "\n",
        "    # for each model\n",
        "    for n in range(num_models):\n",
        "      # restart recycle\n",
        "      I[\"prev\"] = {'init_msa_first_row': np.zeros([1, max_len, 256]),\n",
        "                  'init_pair': np.zeros([1, max_len, max_len, 128]),\n",
        "                  'init_pos': np.zeros([1, max_len, 37, 3])}\n",
        "      for r in range(num_recycles + 1):\n",
        "        O = runner(I, params[n])\n",
        "        O = jax.tree_map(lambda x:np.asarray(x), O)\n",
        "        I[\"prev\"] = O[\"prev\"]\n",
        "        pbar.update(1)\n",
        "\n",
        "      positions[-1].append(O[\"final_atom_positions\"][:length])\n",
        "      plddts[-1].append(O[\"plddt\"][:length])\n",
        "      paes[-1].append(O[\"pae\"][:length,:length])\n",
        "      outs[-1].append(O)\n",
        "      save_pdb(outs[-1][-1], f\"out_seq_{s}_model_{n}.pdb\", Ls=LS[-1])\n",
        "      tmscores = mmalign(pdb_path, f\"out_seq_{s}_model_{n}.pdb\")\n",
        "      # MMalign yields no scores if it failed to build or run; report nan instead of crashing\n",
        "      tmscore = tmscores[-1] if tmscores else float(\"nan\")\n",
        "      print(f\"{s} {n}\\t{plddts[-1][-1].mean():.3}\\t{paes[-1][-1].mean():.3}\\t{tmscore:.3}\")"
      ],
      "metadata": {
        "cellView": "form",
        "id": "p2uNokqudTSH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Display 3D structure {run: \"auto\"}\n",
        "#@markdown #### select which sequence to show (if more than one designed example)\n",
        "seq_num = 0 #@param [\"0\",\"1\",\"2\",\"3\",\"4\",\"5\",\"6\",\"7\"] {type:\"raw\"}\n",
        "assert seq_num < len(outs), f\"ERROR: seq_num ({seq_num}) must be smaller than the number of designed sequences ({len(outs)})\"\n",
        "model_num = 0 #@param [\"0\",\"1\",\"2\",\"3\",\"4\"] {type:\"raw\"}\n",
        "# bound by the models actually run for the selected sequence\n",
        "assert model_num < len(outs[seq_num]), f\"ERROR: model_num ({model_num}) must be smaller than the number of models run ({len(outs[seq_num])})\"\n",
        "#@markdown #### options\n",
        "\n",
        "color = \"confidence\" #@param [\"chain\", \"confidence\", \"rainbow\"]\n",
        "if color == \"confidence\": color = \"lDDT\"\n",
        "show_sidechains = False #@param {type:\"boolean\"}\n",
        "show_mainchains = False #@param {type:\"boolean\"}\n",
        "\n",
        "v = cf.show_pdb(f\"out_seq_{seq_num}_model_{model_num}.pdb\", show_sidechains, show_mainchains, color,\n",
        "                color_HP=True, size=(800,480), Ls=LS[seq_num])\n",
        "v.setHoverable({}, True,\n",
        "               '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel(\"      \"+atom.resn+\":\"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}''',\n",
        "               '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')\n",
        "v.show()\n",
        "if color == \"lDDT\":\n",
        "  cf.plot_plddt_legend().show()\n",
        "\n",
        "# add confidence plots\n",
        "cf.plot_confidence(plddts[seq_num][model_num]*100, paes[seq_num][model_num], Ls=LS[seq_num]).show()"
      ],
      "metadata": {
        "cellView": "form",
        "id": "0TNhcwok8d_w"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "name": "quickdemo_wAF2.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}